WIP: experiment with first class dim objects #1517
Conversation
Is this still true? I would think that when a user is working they may specify dims by label, so they say … Also thinking the user can do …
In the current code that is still true.
We can treat dims as symbols (not sure if that's the term), since in an xarray Dataset you can't have duplicate dims with different meanings either? But it's a choice, not a requirement.
```python
raise NotImplementedError("Subclass did not implement dim broadcasting")
```

```python
class BasicDim(DimType):
```
Nit: I would perhaps split dim_type / var_type into separate files, this one is already pretty long as is
```python
return Product()(*dims, name=name)
```

```python
def rebase_dim(dim: DimVariable | DimType, *tensors: XTensorVariable) -> DimVariable:
```
What is the purpose of `rebase_dim`?
Create a dim from an existing xtensor / get the length at runtime?
That's a helper for rewrites to avoid infinite loops. For instance, in `Elemwise`:
```python
@register_lower_xtensor
@node_rewriter(tracks=[XElemwise])
def lower_elemwise(fgraph, node):
    assert len(node.outputs) == 1
    out_dims = node.outputs[0].dims
    out_dims = [rebase_dim(dim, *node.inputs) for dim in out_dims]
    # Convert input XTensors to Tensors and align batch dimensions
    tensor_inputs = [lower_aligned(inp, out_dims) for inp in node.inputs]
    tensor_outs = Elemwise(scalar_op=node.op.scalar_op)(
        *tensor_inputs, return_list=True
    )
    # Convert output Tensors to XTensors
    new_outs = [
        xtensor_from_tensor(tensor_out, dims=out_dims, check=False)
        for tensor_out in tensor_outs
    ]
    return new_outs
```
The final `XTensorFromTensor` op takes the dim variables as inputs. If we were to use `node.outputs[0].dims` for those, the returned graph would still hold a reference to the `XElemwise` we want to replace, because those dims are variables that use `DimFromTensor(XElemwise)` to get a reference to the dimension length.
Looking good. Do you already have any op that generates its own dims working?
```diff
@@ -96,7 +110,7 @@ def var(x, dim: REDUCE_DIM, *, ddof: int = 0):
     x = as_xtensor(x)
     x_mean = mean(x, dim)
     n = _infer_reduced_size(x, x_mean)
-    return square(x - x_mean) / (n - ddof)
+    return square(x - x_mean).mean(dim) / (n - ddof)
```
I just fixed this. Sanity check for myself: it should be `sum`, right? We then use `n - ddof` on it.
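For reference, the usual `ddof`-corrected variance estimator is:

$$\mathrm{Var}(x) = \frac{1}{n - \mathrm{ddof}} \sum_{i=1}^{n} (x_i - \bar{x})^2$$

i.e. the squared deviations are summed (not averaged) over `dim` before dividing by `n - ddof`.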
```diff
-    return square(x - x_mean).mean(dim) / (n - ddof)
+    return square(x - x_mean).sum(dim) / (n - ddof)
```
Oh, yes, I missed that we divide 🤦
I'm currently working on …
Named Dimensions Refactor: Objects Instead of Strings
I'm still working on this, but thought it might be helpful to share what I have so far...
The Key Change
In this version of named-dims, we use objects to represent dimensions instead of plain strings. This allows us to ensure that array axes with shared dimensions are always compatible, eliminating shape errors as long as you stay in dim-land.
The Two Parts of a Dimension
We can think of a dimension as having two components:
- Size (or length) - this might be known statically or only at runtime.
- Identity - just because two tensors happen to have the same length doesn't mean they're compatible. The identity decides if two tensors can be combined. Crucially, if two dimensions share the same identity, they must always have the same length.
This is similar to vector spaces in math: you can't add a 3D velocity vector to a 3D position vector, even though both are 3D. The mathematical operations care about the meaning of the dimensions, not just their size.
Implementation: Types and Variables
We implement this split using PyTensor's type system:
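For instance (assuming `px` is the `pytensor.xtensor` namespace and that `px.dim` takes just a name; the exact signature may differ):

```python
import pytensor.xtensor as px  # assumed alias

# A brand-new dimension: its identity lives in foo.type (a DimType),
# and foo itself is a DimVariable whose runtime value is the dim's length.
foo = px.dim("foo")
```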
Each dimension has a `Type` (an instance of `DimType`) for its identity, and its size is carried by a `DimVariable`. The object `foo` itself is a `DimVariable` - at runtime, this represents the size of dimension `foo`.
Creating Tensors with Dimensions
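Continuing the snippet above, something along these lines (the constructor shown here is my assumption, not necessarily what this branch exposes):

```python
bar = px.dim("bar")

# Hypothetical constructor: the resulting XTensorVariable records the
# identities of foo and bar in its type.
x = px.xtensor("x", dims=(foo, bar))
```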
The tensor `x` remembers the identity of dimension `foo` in its type. It doesn't need to store the `DimVariable` separately, because it can recreate one from the tensor itself when needed.
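I'm guessing at the exact accessor, but conceptually:

```python
# Rebuild a DimVariable from the tensor itself (cf. DimFromTensor above);
# its runtime value is read off x's shape.
foo_again = x.dims[0]
```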
Ensuring Dimension Uniqueness

To prevent shape errors, we need to avoid having two unrelated `DimVariable`s with the same type. Every call to `px.dim()` creates a truly unique dimension; we use random UUIDs in the type to guarantee uniqueness.
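For example (same caveat about the exact `px.dim` signature):

```python
dim_a = px.dim("foo")
dim_b = px.dim("foo")

# Same label, but each call minted a fresh UUID in the type,
# so these are two unrelated dimensions.
assert dim_a.type != dim_b.type
```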
The Size Invariant
For consistent graphs, we maintain this invariant: during function execution, if two `DimVariable`s have the same type, their runtime values are also the same.

This works because `DimVariable`s can only be created in three ways:

- `px.dim()` creates a new unique type, so it can't share its type with anything else.
- Dims taken from a tensor: either we used the `DimVariable` to create the tensor, so the length is consistent, or the tensor was user-provided. In that case we must add a consistency check of the user input.
- Dims derived from other `DimVariable`s: if the inputs are consistent, the outputs are too.

The main challenge is user input validation - we need to verify that input tensors match their declared dimensions before execution.
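As a purely conceptual sketch of that validation (not code from this PR), the runtime check amounts to:

```python
# Conceptual sketch only: before execution, every user-provided input's
# axis lengths must match the runtime lengths of its declared dims.
def check_input_matches_dims(shape: tuple[int, ...], dim_lengths: tuple[int, ...]) -> None:
    for axis, (actual, expected) in enumerate(zip(shape, dim_lengths)):
        if actual != expected:
            raise ValueError(
                f"Axis {axis} has length {actual}, but its dim has length {expected}"
            )
```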
Small sidenote:
Unfortunately, there is a way users can create two unrelated `DimVariable` objects with the same type.
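Roughly, the problematic pattern is the standard PyTensor idiom of calling a type to get a fresh variable (a sketch; exact spelling aside):

```python
foo = px.dim("foo")  # as above; signature assumed

# foo.type is the DimType carrying foo's identity (its UUID). Calling it
# builds a brand-new DimVariable of the same type, with nothing tying its
# runtime length to foo's.
foo_twin = foo.type()
```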
But if we assume that `foo.type()` is a private function (or maybe we can override the call method to make that clearer), that shouldn't be too much of a problem. We just have to make sure we don't do it ourselves when we add new Ops...

Derived Dimensions
I think we can do a lot of cool things with derived dimensions, but I'm still working on those.
One simple example that already works is a `ClonedDim`. We don't allow duplicate dimensions in one tensor, to simplify indexing and xarray compatibility, but in many cases a user might still need essentially the same dim in a tensor twice (for instance, for a covariance matrix). We can use a cloned dimension for that: a cloned dimension always has the same length as its base dimension, but it has a new identity.
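For instance (the clone helper and tensor constructor names below are my guesses, not necessarily this PR's API):

```python
foo = px.dim("foo")

# Hypothetical helper: foo_clone always has the same runtime length as foo,
# but a new identity, so it counts as a distinct dimension within one tensor.
foo_clone = px.clone_dim(foo)

# e.g. a covariance matrix over foo:
cov = px.xtensor("cov", dims=(foo, foo_clone))  # constructor name also assumed
```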
@OriolAbril @ricardoV94

📚 Documentation preview 📚: https://pytensor--1517.org.readthedocs.build/en/1517/